{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This notebook generates mel-spectrograms from a trained TTS model (run with teacher forcing) so that they can be used to train a vocoder."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import importlib\n",
    "import os\n",
    "import pickle\n",
    "\n",
    "# Restrict the visible GPUs before torch is imported: CUDA_VISIBLE_DEVICES only\n",
    "# takes effect if it is set before CUDA gets initialised in this process.\n",
    "os.environ['CUDA_VISIBLE_DEVICES'] = '2'\n",
    "\n",
    "import numpy as np\n",
    "import soundfile as sf\n",
    "import torch\n",
    "from matplotlib import pylab as plt\n",
    "from torch.utils.data import DataLoader\n",
    "from tqdm import tqdm\n",
    "\n",
    "from TTS.config import load_config\n",
    "from TTS.tts.configs.shared_configs import BaseDatasetConfig\n",
    "from TTS.tts.datasets import load_tts_samples\n",
    "from TTS.tts.datasets.dataset import TTSDataset\n",
    "from TTS.tts.layers.losses import L1LossMasked\n",
    "from TTS.tts.models import setup_model\n",
    "from TTS.tts.utils.helpers import sequence_mask\n",
    "from TTS.tts.utils.text.tokenizer import TTSTokenizer\n",
    "from TTS.tts.utils.visual import plot_spectrogram\n",
    "from TTS.utils.audio import AudioProcessor\n",
    "from TTS.utils.audio.numpy_transforms import quantize\n",
    "\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Function to create directories and file names\n",
    "def set_filename(wav_path, out_path):\n",
    "    \"\"\"Create the output sub-folders and return the output paths for one wav file.\n",
    "\n",
    "    Args:\n",
    "        wav_path: path of the source wav file.\n",
    "        out_path: root output folder; `quant/` and `mel/` are created inside it.\n",
    "\n",
    "    Returns:\n",
    "        (file_name, wavq_path, mel_path): the wav basename without its extension and the\n",
    "        extension-less paths of the quantized-wav and mel files (`np.save` appends `.npy`).\n",
    "    \"\"\"\n",
    "    # splitext strips only the final extension, so basenames containing dots stay unique\n",
    "    file_name = os.path.splitext(os.path.basename(wav_path))[0]\n",
    "    wavq_dir = os.path.join(out_path, \"quant\")\n",
    "    mel_dir = os.path.join(out_path, \"mel\")\n",
    "    os.makedirs(wavq_dir, exist_ok=True)\n",
    "    os.makedirs(mel_dir, exist_ok=True)\n",
    "    return file_name, os.path.join(wavq_dir, file_name), os.path.join(mel_dir, file_name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Paths and configurations\n",
    "DATA_PATH = \"/home/ubuntu/TTS/recipes/ljspeech/LJSpeech-1.1/\"\n",
    "OUT_PATH = os.path.join(DATA_PATH, \"specs2\")\n",
    "PHONEME_CACHE_PATH = os.path.join(DATA_PATH, \"phoneme_cache\")\n",
    "DATASET = \"ljspeech\"\n",
    "METADATA_FILE = \"metadata.csv\"\n",
    "MODEL_DIR = \"/home/ubuntu/.local/share/tts/tts_models--en--ljspeech--tacotron2-DDC_ph\"\n",
    "CONFIG_PATH = os.path.join(MODEL_DIR, \"config.json\")\n",
    "MODEL_FILE = os.path.join(MODEL_DIR, \"model_file.pth\")\n",
    "BATCH_SIZE = 32\n",
    "\n",
    "QUANTIZE_BITS = 0  # if non-zero, quantize wav files with the given number of bits\n",
    "DRY_RUN = False  # if True, does not generate output files, only computes loss and visuals.\n",
    "\n",
    "# Check CUDA availability\n",
    "use_cuda = torch.cuda.is_available()\n",
    "print(\" > CUDA enabled: \", use_cuda)\n",
    "\n",
    "# Load the configuration\n",
    "dataset_config = BaseDatasetConfig(formatter=DATASET, meta_file_train=METADATA_FILE, path=DATA_PATH)\n",
    "C = load_config(CONFIG_PATH)\n",
    "# IMPORTANT: keep silence trimming disabled so the generated mels stay frame-aligned with the wav files\n",
    "C.audio['do_trim_silence'] = False\n",
    "ap = AudioProcessor(**C.audio)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Initialize the tokenizer (init_from_config also returns a possibly updated config)\n",
    "tokenizer, C = TTSTokenizer.init_from_config(C)\n",
    "\n",
    "# Load the model\n",
    "# TODO: multiple speakers\n",
    "model = setup_model(C)\n",
    "model.load_checkpoint(C, MODEL_FILE, eval=True)\n",
    "# The extraction loop runs on the GPU when CUDA is available, so the weights must live there too\n",
    "if use_cuda:\n",
    "    model.cuda()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Extract spectrograms for every sample: the train and eval splits together\n",
    "train_samples, eval_samples = load_tts_samples(dataset_config)\n",
    "meta_data = [*train_samples, *eval_samples]\n",
    "\n",
    "dataset = TTSDataset(\n",
    "    ap=ap,\n",
    "    samples=meta_data,\n",
    "    tokenizer=tokenizer,\n",
    "    outputs_per_step=C[\"r\"],\n",
    "    compute_linear_spec=False,\n",
    "    phoneme_cache_path=PHONEME_CACHE_PATH,\n",
    ")\n",
    "# Dataset order, nothing dropped: every sample is processed exactly once\n",
    "loader = DataLoader(\n",
    "    dataset,\n",
    "    batch_size=BATCH_SIZE,\n",
    "    shuffle=False,\n",
    "    drop_last=False,\n",
    "    num_workers=4,\n",
    "    collate_fn=dataset.collate_fn,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
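    "### Generate model outputs\n",
    "Run the model with the ground-truth mels as decoder input (teacher forcing), compute the decoder and postnet L1 losses and, unless `DRY_RUN` is set, save the predicted mel-spectrograms under `OUT_PATH`."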
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Teacher-forced pass over the whole dataset: compute the decoder/postnet L1 losses and,\n",
    "# unless DRY_RUN, save the predicted mels plus the index files used by the vocoder recipes.\n",
    "file_idxs = []  # output file names (WaveRNN: dataset_ids.pkl)\n",
    "metadata = []  # [wav path, mel path without \".npy\"] pairs (PWGAN: metadata.txt)\n",
    "losses = []\n",
    "postnet_losses = []\n",
    "n_failed_batches = 0\n",
    "criterion = L1LossMasked(seq_len_norm=C.seq_len_norm)\n",
    "# Send every batch to the device that holds the model weights (GPU or CPU)\n",
    "device = next(model.parameters()).device\n",
    "# Compare case-insensitively: depending on the TTS version the config may say \"tacotron\" or \"Tacotron\"\n",
    "is_tacotron = str(C.model).lower() == \"tacotron\"\n",
    "\n",
    "# The log file is opened before set_filename() creates any sub-folder, so make sure OUT_PATH exists\n",
    "os.makedirs(OUT_PATH, exist_ok=True)\n",
    "log_file_path = os.path.join(OUT_PATH, \"log.txt\")\n",
    "# Note the comma: `torch.no_grad() and open(...)` would only enter the file context, never no_grad()\n",
    "with torch.no_grad(), open(log_file_path, \"w\") as log_file:\n",
    "    for data in tqdm(loader, desc=\"Processing\"):\n",
    "        try:\n",
    "            for key in (\"token_id\", \"token_id_lengths\", \"mel\", \"mel_lengths\"):\n",
    "                data[key] = data[key].to(device)\n",
    "\n",
    "            outputs = model.forward(data[\"token_id\"], data[\"token_id_lengths\"], data[\"mel\"])\n",
    "            mel_outputs = outputs[\"decoder_outputs\"]\n",
    "            postnet_outputs = outputs[\"model_outputs\"]\n",
    "\n",
    "            if is_tacotron:\n",
    "                # Tacotron's postnet predicts linear specs: convert them to mel before comparing\n",
    "                # them with the ground-truth mels and saving them\n",
    "                linear_specs = postnet_outputs.detach().cpu().numpy()\n",
    "                mel_specs = np.stack([ap.out_linear_to_mel(spec.T).T for spec in linear_specs])\n",
    "                postnet_outputs = torch.as_tensor(mel_specs, dtype=torch.float32, device=device)\n",
    "\n",
    "            # compute loss\n",
    "            loss = criterion(mel_outputs, data[\"mel\"], data[\"mel_lengths\"])\n",
    "            loss_postnet = criterion(postnet_outputs, data[\"mel\"], data[\"mel_lengths\"])\n",
    "            losses.append(loss.item())\n",
    "            postnet_losses.append(loss_postnet.item())\n",
    "\n",
    "            # CPU numpy copies for saving and for the sanity-check cells below\n",
    "            postnet_outputs = postnet_outputs.detach().cpu().numpy()\n",
    "            alignments = outputs[\"alignments\"].detach().cpu().numpy()\n",
    "\n",
    "            if not DRY_RUN:\n",
    "                mel_lengths = data[\"mel_lengths\"].tolist()\n",
    "                for idx, wav_file_path in enumerate(data[\"item_idxs\"]):\n",
    "                    file_name, wavq_path, mel_path = set_filename(wav_file_path, OUT_PATH)\n",
    "                    file_idxs.append(file_name)\n",
    "\n",
    "                    # quantize and save wav (the wav itself is only needed here)\n",
    "                    if QUANTIZE_BITS > 0:\n",
    "                        wav = ap.load_wav(wav_file_path)\n",
    "                        # NOTE(review): assumes quantize() takes keyword-only `x`/`quantize_bits` - confirm for the installed TTS\n",
    "                        np.save(wavq_path, quantize(x=wav, quantize_bits=QUANTIZE_BITS))\n",
    "\n",
    "                    # save TTS mel: drop the batch padding and store as (channels, frames)\n",
    "                    mel = postnet_outputs[idx][: mel_lengths[idx], :].T\n",
    "                    np.save(mel_path, mel)\n",
    "\n",
    "                    metadata.append([wav_file_path, mel_path])\n",
    "        except Exception as e:  # best effort: skip the failing batch, but record which one and why\n",
    "            n_failed_batches += 1\n",
    "            log_file.write(f\"Error processing batch {list(data['item_idxs'])}: {e!r}\\n\")\n",
    "\n",
    "    # Calculate and log mean losses (NaN when no batch succeeded)\n",
    "    mean_loss = np.mean(losses) if losses else float(\"nan\")\n",
    "    mean_postnet_loss = np.mean(postnet_losses) if postnet_losses else float(\"nan\")\n",
    "    log_file.write(f\"Mean Loss: {mean_loss}\\n\")\n",
    "    log_file.write(f\"Mean Postnet Loss: {mean_postnet_loss}\\n\")\n",
    "\n",
    "if n_failed_batches:\n",
    "    print(f\" > WARNING: {n_failed_batches} batch(es) failed, see {log_file_path}\")\n",
    "\n",
    "if not DRY_RUN:\n",
    "    # For wavernn\n",
    "    with open(os.path.join(OUT_PATH, \"dataset_ids.pkl\"), \"wb\") as f:\n",
    "        pickle.dump(file_idxs, f)\n",
    "\n",
    "    # For pwgan: one \"<wav path>|<mel .npy path>\" line per item\n",
    "    with open(os.path.join(OUT_PATH, \"metadata.txt\"), \"w\") as f:\n",
    "        for wav_file_path, mel_path in metadata:\n",
    "            f.write(f\"{wav_file_path}|{mel_path}.npy\\n\")\n",
    "\n",
    "# Print mean losses\n",
    "print(f\"Mean Loss: {mean_loss}\")\n",
    "print(f\"Mean Postnet Loss: {mean_postnet_loss}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
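    "### Sanity check\n",
    "Compare the outputs for item `idx` of the **last batch** processed above with the ground-truth mel-spectrogram computed from its wav file."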
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# `data` is the last batch processed above; clamp idx in case that batch holds a single item\n",
    "idx = min(1, len(data[\"item_idxs\"]) - 1)\n",
    "ap.melspectrogram(ap.load_wav(data[\"item_idxs\"][idx])).shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the wav through `ap` (as in the cell above) so the ground truth follows the AudioProcessor settings\n",
    "wav = ap.load_wav(data[\"item_idxs\"][idx])\n",
    "n_frames = int(data[\"mel_lengths\"][idx])\n",
    "mel_postnet = postnet_outputs[idx][:n_frames, :]\n",
    "mel_decoder = mel_outputs[idx][:n_frames, :].detach().cpu().numpy()\n",
    "mel_truth = ap.melspectrogram(wav)\n",
    "print(mel_truth.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# postnet output, already trimmed to the item's true number of frames\n",
    "print(mel_postnet.shape)\n",
    "plot_spectrogram(mel_postnet, ap)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# decoder output (the model's decoder_outputs, i.e. before the postnet)\n",
    "print(mel_decoder.shape)\n",
    "plot_spectrogram(mel_decoder, ap)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# ground-truth spectrogram, transposed to (frames, channels) like the model outputs\n",
    "print(mel_truth.shape)\n",
    "plot_spectrogram(mel_truth.T, ap)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# |decoder - postnet| difference (both are already trimmed to the item's length)\n",
    "mel_diff = mel_decoder - mel_postnet\n",
    "plt.figure(figsize=(16, 10))\n",
    "plt.imshow(np.abs(mel_diff).T, aspect=\"auto\", origin=\"lower\")\n",
    "plt.title(\"|decoder - postnet| mel difference\")\n",
    "plt.xlabel(\"frame\")\n",
    "plt.ylabel(\"mel channel\")\n",
    "plt.colorbar()\n",
    "plt.tight_layout()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# |ground truth - decoder| difference\n",
    "mel_diff2 = mel_truth.T - mel_decoder\n",
    "plt.figure(figsize=(16, 10))\n",
    "plt.imshow(np.abs(mel_diff2).T, aspect=\"auto\", origin=\"lower\")\n",
    "plt.title(\"|ground truth - decoder| mel difference\")\n",
    "plt.xlabel(\"frame\")\n",
    "plt.ylabel(\"mel channel\")\n",
    "plt.colorbar()\n",
    "plt.tight_layout()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# |ground truth - postnet| difference (full postnet output cut to the ground-truth length)\n",
    "mel = postnet_outputs[idx]\n",
    "mel_diff2 = mel_truth.T - mel[:mel_truth.shape[1]]\n",
    "plt.figure(figsize=(16, 10))\n",
    "plt.imshow(np.abs(mel_diff2).T, aspect=\"auto\", origin=\"lower\")\n",
    "plt.title(\"|ground truth - postnet| mel difference\")\n",
    "plt.xlabel(\"frame\")\n",
    "plt.ylabel(\"mel channel\")\n",
    "plt.colorbar()\n",
    "plt.tight_layout()"
   ]
  }
 ],
 "metadata": {
  "interpreter": {
   "hash": "822ce188d9bce5372c4adbb11364eeb49293228c2224eb55307f4664778e7f56"
  },
  "kernelspec": {
   "display_name": "Python 3.9.7 64-bit ('base': conda)",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
